from typing import Tuple, List

import jax.numpy as jnp
import jax
from functools import partial

from common import Batch, InfoDict, Model, Params, PRNGKey

import policy


def chi_square_loss(diff, alpha, args=None):
    """Chi-square style loss on a residual.

    Computes ``alpha * max(diff + diff**2 / 4, 0) - (1 - alpha) * diff``
    element-wise.

    Args:
        diff: residual value(s); any array-like accepted by ``jax.numpy``.
        alpha: mixing coefficient between the clipped term and the linear term.
        args: unused; accepted so all loss functions share one call signature.

    Returns:
        The element-wise loss, same shape as ``diff``.
    """
    clipped_term = jnp.maximum(diff + diff**2 / 4, 0)
    linear_term = (1 - alpha) * diff
    return alpha * clipped_term - linear_term

def total_variation_loss(diff, alpha, args=None):
    """Total-variation style loss on a residual.

    Computes ``alpha * max(diff, 0) - (1 - alpha) * diff`` element-wise.

    Args:
        diff: residual value(s); any array-like accepted by ``jax.numpy``.
        alpha: mixing coefficient between the hinge term and the linear term.
        args: unused; accepted so all loss functions share one call signature.

    Returns:
        The element-wise loss, same shape as ``diff``.
    """
    hinge_term = jnp.maximum(diff, 0)
    linear_term = (1 - alpha) * diff
    return alpha * hinge_term - linear_term

def recoil_loss(diff, alpha, args=None):
    """Exponential loss capped at 100 plus a scaled hinge term.

    Computes ``min(exp(alpha * diff), 100) + alpha * max(diff, 0)``
    element-wise.

    Args:
        diff: residual value(s); any array-like accepted by ``jax.numpy``.
        alpha: scale applied inside the exponent and to the hinge term.
        args: unused; accepted so all loss functions share one call signature.

    Returns:
        The element-wise loss, same shape as ``diff``.
    """
    # The cap keeps the exponential from blowing up for large residuals.
    capped_exp = jnp.minimum(jnp.exp(alpha * diff), 100)
    hinge_term = alpha * jnp.maximum(diff, 0)
    return capped_exp + hinge_term

def reverse_kl_loss(diff, alpha, args=None):
    """Gumbel loss J: E[e^x - x - 1], rescaled for numerical stability.

    For stability to outliers, the loss is scaled by ``exp(-max_z)`` where
    ``max_z`` is the maximum of ``z = diff / alpha`` along axis 0 (floored at
    -1 and excluded from differentiation). This has the effect of training
    with an adaptive learning rate. The exponent is optionally clipped from
    above.

    Args:
        diff: residual array with at least one dimension (the maximum is taken
            over axis 0).
        alpha: temperature dividing the residual.
        args: optional config object. When given and ``args.max_clip`` is not
            None, ``z`` is clipped from above at that value. ``None`` (the
            default) disables clipping.

    Returns:
        The element-wise loss, same shape as ``diff``.
    """
    z = diff / alpha
    # `args` defaults to None, so only read `max_clip` when a config was passed.
    max_clip = None if args is None else args.max_clip
    if max_clip is not None:
        z = jnp.minimum(z, max_clip)  # clip max value
    max_z = jnp.max(z, axis=0)
    # Never scale by less than e^-1, so small batches-wide maxima do not
    # inflate the loss.
    max_z = jnp.where(max_z < -1.0, -1.0, max_z)
    max_z = jax.lax.stop_gradient(max_z)  # Detach the gradients
    # (e^z - z - 1) scaled by e^-max_z.
    loss = jnp.exp(z - max_z) - z * jnp.exp(-max_z) - jnp.exp(-max_z)
    return loss

def expectile_loss(diff, expectile=0.8):
    """Asymmetric squared loss.

    Positive residuals are weighted by ``expectile`` and all other residuals
    by ``1 - expectile``.

    Args:
        diff: residual value(s); any array-like accepted by ``jax.numpy``.
        expectile: weight applied where ``diff > 0``.

    Returns:
        The element-wise weighted squared residual, same shape as ``diff``.
    """
    asym_weight = jnp.where(diff > 0, expectile, 1 - expectile)
    squared = diff**2
    return asym_weight * squared
 
def pagar_update_v(reward: Model, value: Model, batch: Batch, discount: float, 
             expectile: float, loss_temp: float, double: bool, vanilla: bool, key: PRNGKey, args) -> Tuple[Model, InfoDict]:
    """Run one gradient step on the value model using rewards from `reward`.

    The target is ``q = r + discount * masks * V(next_obs)`` and the loss is the
    sum of the f-divergence losses named in ``args.f`` (names joined by ``+``),
    each applied to ``q - v`` with temperature ``loss_temp``.

    Args:
        reward: reward model; called as ``reward(obs, acts, next_obs)`` and
            expected to return a pair ``(r1, r2)``.
        value: value model being updated.
        batch: transitions providing ``observations``, ``actions``,
            ``next_observations`` and ``masks``.
        discount: discount factor applied to the next-state value.
        expectile: unused in this function.
        loss_temp: temperature (``alpha``) forwarded to the loss functions.
        double: when True use ``min(r1, r2)`` as the reward, otherwise ``r1``.
        vanilla: unused in this function.
        key: PRNG key used for random-action sampling and action noise.
        args: config object; reads ``sample_random_times``, ``noise``,
            ``noise_std`` and ``f`` (and whatever the selected loss reads).

    Returns:
        ``(new_value, info)`` as returned by ``value.apply_gradient``; ``info``
        carries ``'value_loss'`` and ``'v'``.
    """
    actions = batch.actions

    rng1, rng2 = jax.random.split(key)
    if args.sample_random_times > 0:
        # add random actions to smooth loss computation (use 1/2(rho + Unif))
        times = args.sample_random_times
        random_action = jax.random.uniform(
            rng1, shape=(times * actions.shape[0],
                         actions.shape[1]),
            minval=-1.0, maxval=1.0)
        obs = jnp.concatenate([batch.observations, jnp.repeat(
            batch.observations, times, axis=0)], axis=0)
        acts = jnp.concatenate([batch.actions, random_action], axis=0)
        nxt_obs = jnp.concatenate([batch.next_observations, jnp.repeat(
            batch.next_observations, times, axis=0)], axis=0)
        masks = jnp.concatenate([batch.masks, jnp.repeat(
            batch.masks, times, axis=0)], axis=0)
    else:
        obs = batch.observations
        acts = batch.actions
        nxt_obs = batch.next_observations
        masks = batch.masks

    if args.noise:
        std = args.noise_std
        noise = jax.random.normal(rng2, shape=(acts.shape[0], acts.shape[1]))
        noise = jnp.clip(noise * std, -0.5, 0.5)
        # Perturb `acts` (not `batch.actions`): `noise` is shaped like `acts`,
        # which also contains the sampled random actions when
        # `sample_random_times > 0`. Without augmentation the two are the same.
        acts = jnp.clip(acts + noise, -1, 1)

    # Original author's note on the reward input variants:
    #   s,a,s' > s,a,zero >>>>>>> s,a,s (almost not working)
    # Next observations are multiplied by the mask before being fed to the
    # reward model.
    r1, r2 = reward(obs, acts, nxt_obs * masks.reshape(-1, 1))
    r = jnp.minimum(r1, r2) if double else r1

    # Loss selected by name; 'recoil' uses the same loss as 'reverse-kl'
    # (same as XQL). Names not listed here contribute nothing to the loss.
    loss_fns = {
        'chi-square': chi_square_loss,
        'total-variation': total_variation_loss,
        'recoil': reverse_kl_loss,
        'reverse-kl': reverse_kl_loss,
    }

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v = value.apply({'params': value_params}, obs)
        nxt_v = value.apply({'params': value_params}, nxt_obs)
        q = r + discount * masks * nxt_v

        value_loss = 0
        for f in args.f.split('+'):
            loss_fn = loss_fns.get(f)
            if loss_fn is not None:
                value_loss += loss_fn(q - v, alpha=loss_temp, args=args).mean()

        return value_loss, {
            'value_loss': value_loss,
            'v': v.mean(),
        }

    new_value, info = value.apply_gradient(value_loss_fn)

    return new_value, info


def pagar_update_r(reward: Model, target_value: Model, protagonist_actor: Model, antagonist_actor: Model, expert_batch: Batch, suboptimal_batch: Batch, mix_batch: Batch, 
             discount: float, double: bool, key: PRNGKey, loss_temp: float, temperature: float, args) -> Tuple[Model, InfoDict]:
    """Run one gradient step on the reward model.

    The loss is built from up to four parts:

    * a Huber regression of the predicted reward on the mixed batch towards
      ``V(s) - discount * mask * V(s')`` computed with ``target_value``;
    * when ``'recoil'`` appears in ``args.f``: a term pulling the expert
      ``r + discount * mask * V(s')`` towards 200 and a term minimising the
      same quantity on the sub-optimal batch;
    * also under ``'recoil'``: the PAGAR term, which weights the predicted
      rewards by the clipped exponentiated log-probability gap between the
      antagonist and protagonist actors (scaled by 1e-3);
    * when ``args.grad_pen`` is set: a gradient penalty scaled by
      ``args.lambda_gp``.

    Args:
        reward: reward model being updated; ``reward.apply`` is expected to
            return a pair ``(r1, r2)``.
        target_value: value model, called directly on observations.
        protagonist_actor: actor whose ``apply_fn``/``params`` are passed to
            ``policy.log_probs``.
        antagonist_actor: second actor, used the same way.
        expert_batch: expert transitions.
        suboptimal_batch: sub-optimal transitions.
        mix_batch: transitions used for the regression and gradient penalty.
        discount: discount factor.
        double: when True both reward heads are trained on the regression term
            and both enter the gradient penalty; otherwise only the first.
        key: PRNG key threaded through ``policy.log_probs``.
        loss_temp: passed through to the inner loss helper, which ignores it.
        temperature: forwarded to ``policy.log_probs``.
        args: config object; reads ``f``, ``grad_pen`` and ``lambda_gp``.

    Returns:
        ``(new_reward, info)`` as returned by ``reward.apply_gradient``.
    """

    def reward_loss_fn(reward_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        mix_next_v = target_value(mix_batch.next_observations)
        mix_target_v = discount * mix_batch.masks * mix_next_v

        mix_acts = mix_batch.actions
        # Next observations are multiplied by the mask before entering the
        # reward model.
        mix_r1, mix_r2 = reward.apply({'params': reward_params}, mix_batch.observations, mix_acts, mix_batch.next_observations * mix_batch.masks.reshape(-1, 1))
        mix_v = target_value(mix_batch.observations)

        def mse_loss(q, q_target, *_unused):
            # Huber loss (delta=20) on q - q_target; extra positional
            # arguments are accepted for call-site compatibility and ignored.
            loss_dict = {}

            x = q - q_target
            loss = huber_loss(x, delta=20.0)  # x**2
            loss_dict['mse_loss'] = loss.mean()

            return loss.mean(), loss_dict

        if double:
            loss1, dict1 = mse_loss(mix_r1, mix_v - mix_target_v, mix_v, loss_temp)
            loss2, dict2 = mse_loss(mix_r2, mix_v - mix_target_v, mix_v, loss_temp)

            reward_loss = (loss1 + loss2).mean()
            for name, val in dict2.items():
                dict1['reward_critic_' + name] = dict1[name] + val
            loss_dict = dict1
        else:
            # The original passed an unbound local `v` here, which raised
            # UnboundLocalError whenever `double` was False; `mix_v` matches
            # the double branch (the helper ignores the extra arguments).
            reward_loss, loss_dict = mse_loss(mix_r1, mix_v - mix_target_v, mix_v, loss_temp)

        if 'recoil' in args.f:
            # Use the config in the paper: 1 million transitions sub-optimal trajs vs 30 expert demonstrations ==> beta = 30/1000000
            expert_acts = expert_batch.actions
            expert_r1, expert_r2 = reward.apply({'params': reward_params},
                                                expert_batch.observations,
                                                expert_acts,
                                                expert_batch.next_observations * expert_batch.masks.reshape(-1, 1))

            expert_q1 = expert_r1 + discount * expert_batch.masks * target_value(expert_batch.next_observations)
            expert_q2 = expert_r2 + discount * expert_batch.masks * target_value(expert_batch.next_observations)

            # Expert Q estimates are regressed towards the constant 200.
            expert_loss1, expert_dict1 = mse_loss(expert_q1, 200)
            expert_loss2, expert_dict2 = mse_loss(expert_q2, 200)

            expert_loss = (expert_loss1 + expert_loss2).mean()

            expert_dict = {('reward_recoil_expert_' + k): v for k, v in expert_dict1.items()}
            for k, v in expert_dict2.items():
                expert_dict['reward_recoil_expert_' + k] += v

            loss_dict.update(expert_dict)

            suboptimal_acts = suboptimal_batch.actions
            suboptimal_r1, suboptimal_r2 = reward.apply({'params': reward_params},
                                                        suboptimal_batch.observations,
                                                        suboptimal_acts,
                                                        suboptimal_batch.next_observations * suboptimal_batch.masks.reshape(-1, 1))

            suboptimal_q1 = suboptimal_r1 + discount * suboptimal_batch.masks * target_value(suboptimal_batch.next_observations)
            suboptimal_q2 = suboptimal_r2 + discount * suboptimal_batch.masks * target_value(suboptimal_batch.next_observations)

            # Minimise the larger of the two Q estimates on sub-optimal data.
            suboptimal_loss = jnp.maximum(suboptimal_q1, suboptimal_q2).mean()
            loss_dict['reward_recoil_suboptimal_loss'] = suboptimal_loss

            # Named `recoil_term` so the module-level `recoil_loss` function is
            # not shadowed inside this closure.
            recoil_term = 4 * 0.9 * (suboptimal_loss + expert_loss)
            reward_loss += recoil_term

            ############ PAGAR ############
            rng = key

            rng, protagonist_expert_log_probs = policy.log_probs(rng, protagonist_actor.apply_fn,
                                    protagonist_actor.params, expert_batch.observations, expert_batch.actions,
                                        temperature)
            rng, antagonist_expert_log_probs = policy.log_probs(rng, antagonist_actor.apply_fn,
                                        antagonist_actor.params, expert_batch.observations, expert_batch.actions,
                                            temperature)
            # Only the non-negative part of the antagonist-minus-protagonist
            # gap is kept on expert data.
            delta_expert_log_probs = jnp.maximum(antagonist_expert_log_probs - protagonist_expert_log_probs, 0.)
            expert_r = jnp.minimum(expert_r1, expert_r2)
            pagar_loss = -(expert_r * jnp.exp(jnp.clip(delta_expert_log_probs, -1., 1))).mean()

            rng, protagonist_suboptimal_log_probs = policy.log_probs(rng, protagonist_actor.apply_fn,
                                    protagonist_actor.params, suboptimal_batch.observations, suboptimal_batch.actions,
                                        temperature)
            rng, antagonist_suboptimal_log_probs = policy.log_probs(rng, antagonist_actor.apply_fn,
                                        antagonist_actor.params, suboptimal_batch.observations, suboptimal_batch.actions,
                                            temperature)
            # Only the non-positive part of the gap is kept on sub-optimal data.
            delta_suboptimal_log_probs = jnp.minimum(antagonist_suboptimal_log_probs - protagonist_suboptimal_log_probs, 0.)
            suboptimal_r = jnp.minimum(suboptimal_r1, suboptimal_r2)
            pagar_loss += -(suboptimal_r * jnp.exp(jnp.clip(delta_suboptimal_log_probs, -1., 1))).mean()

            loss_dict['pagar_reward_loss'] = pagar_loss

            reward_loss += 1e-3 * pagar_loss

        if args.grad_pen:
            lambda_ = args.lambda_gp
            # NOTE(review): `grad_norm` is declared as
            # (model, params, obs, action, lambda_=10), so the masked next
            # observations passed here bind to its `lambda_` parameter and are
            # not forwarded to the reward model — confirm and fix together
            # with `grad_norm`.
            r1_grad, r2_grad = grad_norm(reward, reward_params, mix_batch.observations, mix_acts, mix_batch.next_observations * mix_batch.masks.reshape(-1, 1))
            loss_dict['r1_grad'] = r1_grad.mean()
            loss_dict['r2_grad'] = r2_grad.mean()

            if double:
                gp_loss = (r1_grad + r2_grad).mean()
            else:
                gp_loss = r1_grad.mean()

            reward_loss += lambda_ * gp_loss

        return reward_loss, loss_dict

    new_reward, info = reward.apply_gradient(reward_loss_fn)

    return new_reward, info


def grad_norm(model, params, obs, action, lambda_=10):
    """Per-sample gradient penalty of a two-output model w.r.t. its action input.

    For each leading-axis element of ``obs``/``action`` the Jacobian of
    ``model.apply({'params': params}, obs, action)`` with respect to ``action``
    is computed, and each of the two outputs is penalised with
    ``max(||grad|| - 1, 0) ** 2`` (norm taken over the last axis).

    Args:
        model: object exposing ``apply(variables, obs, action)`` that returns
            a pair of outputs.
        params: parameters passed to ``model.apply`` under the ``'params'`` key.
        obs: batched observations (mapped over axis 0).
        action: batched actions (mapped over axis 0); gradients are taken with
            respect to this input.
        lambda_: unused here; kept for interface compatibility.

    Returns:
        A pair ``(penalty1, penalty2)``, one penalty per model output.
    """

    @partial(jax.vmap, in_axes=(0, 0))
    @partial(jax.jacrev, argnums=1)
    def input_grad_fn(obs, action):
        return model.apply({'params': params}, obs, action)

    def grad_pen_fn(grad):
        # We use gradient penalties inspired from WGAN-LP loss which penalizes grad_norm > 1
        # Use the `grad` argument: the original read the closure variable
        # `grad1`, so the second output's penalty duplicated the first.
        penalty = jnp.maximum(jnp.linalg.norm(grad, axis=-1) - 1, 0)**2
        return penalty

    grad1, grad2 = input_grad_fn(obs, action)

    return grad_pen_fn(grad1), grad_pen_fn(grad2)


def huber_loss(x, delta: float = 1.):
    """Huber loss: quadratic near zero, linear away from zero.

    See "Robust Estimation of a Location Parameter" by Huber
    (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732).

    Piecewise definition::

        0.5 * x^2                  if |x| <= delta
        0.5 * d^2 + d * (|x| - d)  if |x| >  delta

    Note ``grad(huber_loss(x))`` is equivalent to
    ``grad(0.5 * clip_gradient(x)**2)``.

    Args:
        x: a vector of arbitrary shape.
        delta: the bound of the quadratic region, defaults to 1.

    Returns:
        A vector of the same shape as ``x``.
    """
    magnitude = jnp.abs(x)
    # Part of |x| that falls inside the quadratic region.
    inside = jnp.minimum(magnitude, delta)
    # Remainder beyond delta; same as max(|x| - delta, 0) but avoids
    # potentially doubling the gradient.
    beyond = magnitude - inside
    quadratic_term = 0.5 * inside**2
    linear_term = delta * beyond
    return quadratic_term + linear_term
